{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NOTE:  THIS NOTEBOOK WILL TAKE 5-10 MINUTES TO COMPLETE.\n",
    "\n",
    "# PLEASE BE PATIENT."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Serving a PyTorch Model as a REST Endpoint with TorchServe and SageMaker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will deploy our BERT PyTorch model as a REST endpoint on SageMaker using TorchServe: https://github.com/pytorch/serve/\n",
    "\n",
    "TorchServe can be used for many types of inference in production settings. It provides an easy-to-use command line interface and uses REST-based APIs to handle prediction requests.\n",
    "\n",
    "<img src=\"./img/torchserve.png\" width=\"90%\">\n",
    "  \n",
    "More information on how to deploy Huggingface Transformers with TorchServe:\n",
    "* https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers\n",
    "* https://medium.com/analytics-vidhya/deploy-huggingface-s-bert-to-production-with-pytorch-serve-27b068026d18 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "sess   = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "account_id = boto3.client('sts').get_caller_identity().get('Account')\n",
    "\n",
    "sm = boto3.Session().client(service_name='sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PRE-REQUISITE: \n",
    "\n",
    "## You need to have successfully run the notebooks in the `TRAINING` section and converted your TF model into PyTorch before you continue with this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r training_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    training_job_name\n",
    "    print('[OK]')\n",
    "except NameError:\n",
    "    print('+++++++++++++++++++++++++++++++')\n",
    "    print('[ERROR] Please run the notebooks in the previous TRAIN section before you continue.')\n",
    "    print('+++++++++++++++++++++++++++++++')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(training_job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r transformer_pytorch_model_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    transformer_pytorch_model_s3_uri\n",
    "    print('[OK]')\n",
    "except NameError:\n",
    "    print('+++++++++++++++++++++++++++++++')\n",
    "    print('[ERROR] Please run the notebooks in the previous TRAIN section before you continue.')\n",
    "    print('+++++++++++++++++++++++++++++++')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(transformer_pytorch_model_s3_uri)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Copy the Transformer PyTorch Model from S3 to Local"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "local_model_dir = './models/transformers/pytorch/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!aws s3 cp --recursive $transformer_pytorch_model_s3_uri $local_model_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retrieve Transformer PyTorch Model Name (.bin) Created During Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r transformer_pytorch_model_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    transformer_pytorch_model_name\n",
    "    print('[OK]')\n",
    "except NameError:\n",
    "    print('+++++++++++++++++++++++++++++++')\n",
    "    print('[ERROR] Please run the notebooks in the previous TRAIN section before you continue.')\n",
    "    print('+++++++++++++++++++++++++++++++')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(transformer_pytorch_model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create TorchServe Model Archive File (.mar)\n",
    "\n",
    "https://github.com/pytorch/serve/blob/master/model-archiver/README.md\n",
    "\n",
    "A key feature of TorchServe is the ability to package all model artifacts into a single model archive file. A separate command line interface (CLI), `torch-model-archiver`, takes the model artifacts and packages them into a `.mar` file that TorchServe's server CLI uses to serve the model. The artifacts are a model checkpoint file (for TorchScript) or a model definition file plus a `state_dict` file (for eager mode), along with any optional assets required to serve the model. The resulting `.mar` file can be redistributed and served by anyone using TorchServe.\n",
    "\n",
    "We need to pass the following:\n",
    "* `--handler`:  Python code that converts the `review_body` into BERT tokens (request handler) and returns the predicted `star_rating` of 1-5 (response handler)\n",
    "* `config.json`:  model configuration saved by the Huggingface transformers library when we saved the model in a previous notebook\n",
    "* `setup_config.json`:  BERT-specific configuration that defines the `max seq length`, `number of output classes` (1-5), etc.\n",
    "* `Seq_classification_artifacts/index_to_name.json`:  BERT-specific mapping of response index (0-4) to class name (1-5 star rating) for our output classes\n",
    "\n",
    "The three JSON files are passed together through `--extra-files`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p ./model_store"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!torch-model-archiver -f \\\n",
    "    --model-name model \\\n",
    "    --export-path ./model_store/ \\\n",
    "    --version 1.0 \\\n",
    "    --serialized-file $local_model_dir/$transformer_pytorch_model_name \\\n",
    "    --handler ./src_torchserve/Transformer_handler_generalized.py \\\n",
    "    --extra-files \"./models/transformers/pytorch/config.json,./src_torchserve/setup_config.json,./src_torchserve/Seq_classification_artifacts/index_to_name.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -al ./model_store/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start TorchServe locally to serve the model\n",
    "\n",
    "After you archive and store the model, use the torchserve command to serve the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare the Model for SageMaker Deployment\n",
    "\n",
    "To deploy the model to a SageMaker REST endpoint, we need to upload our .mar file to S3 and build a TorchServe model container. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p ./tmp/model_mar/\n",
    "!unzip -o ./model_store/model.mar -d ./tmp/model_mar/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Upload TorchServe Model Archive File to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torchserve_mar = 'model.mar'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tar the `.mar` Archive File as `model.tar.gz` and Upload to S3\n",
    "Per TorchServe convention, the `.mar` file must be under ./model_store/ in the `.tar` archive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tar -cvzf ./tmp/model.tar.gz \\\n",
    "    ./model_store/$torchserve_mar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp_torchserve_model_name = 'reviews-distilbert-pytorch'\n",
    "\n",
    "print(tmp_torchserve_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "tmp_torchserve_tar_s3_uri = 's3://{}/models/torchserve/model.tar.gz'.format(bucket)\n",
    "\n",
    "print(tmp_torchserve_tar_s3_uri)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Upload `model.tar.gz` to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp ./tmp/model.tar.gz $tmp_torchserve_tar_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tmp_torchserve_tar_s3_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 ls $tmp_torchserve_tar_s3_uri"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build a TorchServe Docker Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pygmentize ./docker/Dockerfile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docker_repo = 'torchserve'\n",
    "docker_tag = 'torch-1.5.0-1.0.0'\n",
    "\n",
    "image_uri = f'{account_id}.dkr.ecr.{region}.amazonaws.com/{docker_repo}:{docker_tag}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!docker build -t $docker_repo:$docker_tag -f ./docker/Dockerfile ./docker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check the Docker Image\n",
    "If the image did not build properly, re-run the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!docker inspect $docker_repo:$docker_tag"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Push the Image to a Private Docker Repo (Amazon ECR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "account_id = boto3.client('sts').get_caller_identity().get('Account')\n",
    "region = boto3.session.Session().region_name\n",
    "\n",
    "image_uri = '{}.dkr.ecr.{}.amazonaws.com/{}:{}'.format(account_id, region, docker_repo, docker_tag)\n",
    "print(image_uri)"
   ]
  },
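  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `aws ecr get-login` was removed in AWS CLI v2; `get-login-password` works in v1 (1.17.10+) and v2\n",
    "!aws ecr get-login-password --region $region | docker login --username AWS --password-stdin {account_id}.dkr.ecr.{region}.amazonaws.com"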
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ignore the `RepositoryNotFoundException` Error Below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws ecr describe-repositories --repository-names $docker_repo || aws ecr create-repository --repository-name $docker_repo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!docker tag $docker_repo:$docker_tag $image_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!docker push $image_uri"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create SageMaker Endpoint and Deploy TorchServe Model Container"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "timestamp = int(time.time())\n",
    "\n",
    "pytorch_model_name = '{}-{}-{}'.format(training_job_name, 'pt', timestamp)\n",
    "\n",
    "print(pytorch_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.model import Model\n",
    "from sagemaker.predictor import Predictor\n",
    "\n",
    "pytorch_model = Model(image_uri=image_uri,\n",
    "                      model_data=tmp_torchserve_tar_s3_uri,                       \n",
    "                      role=role,\n",
    "                      predictor_cls=Predictor,\n",
    "                      name=pytorch_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store pytorch_model_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_endpoint_name = '{}-{}-{}'.format(training_job_name, 'pt', timestamp)\n",
    "\n",
    "print(pytorch_endpoint_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = pytorch_model.deploy(instance_type='ml.m5.large',\n",
    "                                 initial_instance_count=1,\n",
    "                                 endpoint_name=pytorch_endpoint_name,\n",
    "                                 wait=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a target=\"blank\" href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/endpoints/{}\">REST Endpoint</a></b>'.format(region, pytorch_endpoint_name)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _Wait Until the Endpoint is Deployed_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "waiter = sm.get_waiter('endpoint_in_service')\n",
    "waiter.wait(EndpointName=pytorch_endpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _Wait Until the ^^ Endpoint ^^ is Deployed_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Waiting for the Endpoint to be ready to Serve Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "time.sleep(30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict the `star_rating` with Ad Hoc `review_body` Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_classes = predictor.predict('This is great!')\n",
    "\n",
    "print(predicted_classes.decode('utf-8'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict the `star_rating` with `review_body` Samples from Our TSVs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "df_reviews = pd.read_csv('./data/amazon_reviews_us_Digital_Software_v1_00.tsv.gz', \n",
    "                         delimiter='\\t', \n",
    "                         quoting=csv.QUOTE_NONE,\n",
    "                         compression='gzip')\n",
    "df_sample_reviews = df_reviews[['review_body', 'star_rating']].sample(n=50)\n",
    "df_sample_reviews = df_sample_reviews.reset_index()\n",
    "df_sample_reviews.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def predict(review_body):\n",
    "    return predictor.predict(review_body).decode('utf-8')\n",
    "\n",
    "df_sample_reviews['predicted_class'] = df_sample_reviews['review_body'].map(predict)\n",
    "df_sample_reviews.head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save for Next Notebook(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store pytorch_endpoint_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Delete Endpoint\n",
    "To save money, delete the endpoint once you no longer need it by uncommenting and running the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sm.delete_endpoint(\n",
    "#     EndpointName=pytorch_endpoint_name\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "Jupyter.notebook.save_checkpoint();\n",
    "Jupyter.notebook.session.delete();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
